# --------------------------------------------------------
# Pytorch Yolov2
# Licensed under The MIT License [see LICENSE for details]
# Written by Jingru Tan
# --------------------------------------------------------
import torch


def box_ious(boxes1, boxes2):
    """
    Compute the pairwise intersection over union (IoU) between two sets of
    boxes encoded as (x1, y1, x2, y2).

    Arguments:
        boxes1: (N, 4), the first set of boxes
        boxes2: (K, 4), the second set of boxes

    Returns:
        ious: (N, K), ious[i, j] is the IoU between boxes1[i] and boxes2[j]

    Note:
        For floating-point inputs, a pair whose union area is zero (e.g. two
        degenerate boxes) yields 0 / 0 = NaN; no epsilon is added here.
    """

    N = boxes1.size(0)
    K = boxes2.size(0)

    # Reshape the columns to (N, 1) and (1, K) so that the elementwise
    # max/min below broadcast to an (N, K) matrix of pairwise values.
    xi1 = torch.max(boxes1[:, 0].view(N, 1), boxes2[:, 0].view(1, K))
    yi1 = torch.max(boxes1[:, 1].view(N, 1), boxes2[:, 1].view(1, K))
    xi2 = torch.min(boxes1[:, 2].view(N, 1), boxes2[:, 2].view(1, K))
    yi2 = torch.min(boxes1[:, 3].view(N, 1), boxes2[:, 3].view(1, K))

    # Non-overlapping pairs produce negative extents; clamp them to zero.
    # clamp(min=0) keeps the dtype/device of its input and avoids building a
    # zero tensor just to sidestep torch.max(tensor, int) being parsed as a
    # reduction over dimension `int`.
    iw = (xi2 - xi1).clamp(min=0)
    ih = (yi2 - yi1).clamp(min=0)

    inter = iw * ih

    boxes1_area = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    boxes2_area = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # Broadcast the per-box areas against each other to (N, K).
    union_area = boxes1_area.view(N, 1) + boxes2_area.view(1, K) - inter

    ious = inter / union_area
    return ious


def xxyy2xywh(boxes):
    """
    Re-encode boxes from corner form (x1, y1, x2, y2) into centre form
    (c_x, c_y, w, h).

    Arguments:
        boxes: (N, 4), N boxes of (x1, y1, x2, y2)

    Returns:
        xywh_boxes: (N, 4), N boxes of (c_x, c_y, w, h)
    """

    left, top, right, bottom = (boxes[:, col] for col in range(4))

    # centre = midpoint of the two corners, size = corner-to-corner extent
    columns = (
        (right + left) / 2,
        (bottom + top) / 2,
        right - left,
        bottom - top,
    )
    return torch.cat([column.view(-1, 1) for column in columns], dim=1)


def xywh2xxyy(boxes):
    """
    Re-encode boxes from centre form (c_x, c_y, w, h) into corner form
    (x1, y1, x2, y2).

    Arguments:
        boxes: (N, 4), N boxes of (c_x, c_y, w, h)

    Returns:
        xxyy_boxes: (N, 4), N boxes of (x1, y1, x2, y2)
    """

    center_x, center_y = boxes[:, 0], boxes[:, 1]
    half_w = boxes[:, 2] / 2
    half_h = boxes[:, 3] / 2

    # top-left corner is centre minus half the size, bottom-right is plus
    corners = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    return torch.stack(corners, dim=1)


def box_transform(boxes1, boxes2):
    """
    Compute the per-row offsets that carry each box of ``boxes1`` onto the
    corresponding box of ``boxes2``.

    Centre offsets are plain differences and size terms are plain ratios; no
    sigmoid or exponential is applied here. In YOLOv2 notation these four
    values are the quantities σ(t_x), σ(t_y), exp(t_w), exp(t_h).

    Arguments:
        boxes1: (N, 4), reference boxes (c_x, c_y, w, h)
        boxes2: (N, 4), target boxes (c_x, c_y, w, h)

    Returns:
        deltas: (N, 4), rows of (c_x2 - c_x1, c_y2 - c_y1, w2 / w1, h2 / h1).
                For floating-point inputs, a zero width/height in boxes1
                yields inf/NaN in the ratio.
    """

    src_cx, src_cy, src_w, src_h = (boxes1[:, col] for col in range(4))
    dst_cx, dst_cy, dst_w, dst_h = (boxes2[:, col] for col in range(4))

    return torch.stack(
        [dst_cx - src_cx, dst_cy - src_cy, dst_w / src_w, dst_h / src_h],
        dim=1,
    )


def box_transform_inv(boxes, deltas):
    """
    Apply per-row deltas to base boxes to obtain predicted boxes; the
    inverse of computing centre differences and size ratios.

    Arguments:
         boxes: (N, 4), N base boxes of (c_x, c_y, w, h)
        deltas: (N, 4), N deltas of (σ(t_x), σ(t_y), exp(t_w), exp(t_h))

    Returns:
        pred_boxes: (N, 4), N predicted boxes of (c_x, c_y, w, h), where the
                    centre is shifted by the first two deltas and the size is
                    scaled by the last two
    """

    base_cx, base_cy = boxes[:, 0], boxes[:, 1]
    base_w, base_h = boxes[:, 2], boxes[:, 3]

    pred_columns = (
        base_cx + deltas[:, 0],  # shifted centre x
        base_cy + deltas[:, 1],  # shifted centre y
        base_w * deltas[:, 2],   # scaled width
        base_h * deltas[:, 3],   # scaled height
    )
    return torch.cat([column.view(-1, 1) for column in pred_columns], dim=-1)


def generate_all_anchors(anchors, H, W):
    """
    Generate dense anchors given grid defined by (H, W).

    Every grid cell (x, y) receives one copy of each pre-defined anchor, with
    the cell's integer coordinates as its centre. Rows are ordered cell by
    cell in row-major (y-major, x-minor) order, and within a cell by anchor.

    Arguments:
        anchors: (num_anchors, 2), pre-defined anchors (pw, ph) on each cell
              H: int, grid height
              W: int, grid width

    Returns:
        all_anchors: (H * W * num_anchors, 4), dense grid anchors (c_x, c_y, w, h)
    """

    # number of anchors per cell
    A = anchors.size(0)

    # number of cells
    K = H * W

    # Build the grid on the same device as `anchors`: torch.cat below needs
    # every input on one device, so a CPU-only grid breaks for CUDA anchors.
    device = anchors.device

    # indexing="xy" yields (H, W) grids with shift_x[i, j] = j and
    # shift_y[i, j] = i, so flattening walks the cells in HxW order.
    shift_x, shift_y = torch.meshgrid(
        torch.arange(0, W, device=device),
        torch.arange(0, H, device=device),
        indexing="xy",
    )

    # shift_x is a long tensor, c_x is a float tensor
    c_x = shift_x.float()
    c_y = shift_y.float()

    # tensor of shape (H * W, 2) with value of (cx, cy); reshape (not view)
    # because meshgrid outputs may be non-contiguous expanded views.
    centers = torch.cat([c_x.reshape(-1, 1), c_y.reshape(-1, 1)], dim=-1)

    # pair every cell centre with every anchor size: (K, A, 4)
    all_anchors = torch.cat(
        [centers.view(K, 1, 2).expand(K, A, 2), anchors.view(1, A, 2).expand(K, A, 2)],
        dim=-1,
    )

    all_anchors = all_anchors.view(-1, 4)
    return all_anchors
